-
Notifications
You must be signed in to change notification settings - Fork 528
[MRG] Added argument for warmstart of dual vectors in Sinkhorn-based methods in ot.bregman
#437
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
ot.bregman
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## master #437 +/- ##
==========================================
+ Coverage 94.69% 94.70% +0.01%
==========================================
Files 24 24
Lines 6593 6608 +15
==========================================
+ Hits 6243 6258 +15
Misses 350 350 |
ot.bregman
ot.bregman
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small change in position of parameter wramstart and few doc changes.
Thanks for the PR @6Ulm
ot/bregman.py
Outdated
@@ -93,6 +93,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, | |||
those function for specific parameters | |||
numItermax : int, optional | |||
Max number of iterations | |||
warmstart: tuple of arrays, shape (dim_a, dim_b), optional | |||
Initialization of dual vectors. If provided, | |||
the dual vectors must be already taken the logarithm, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the dual vectors must be already taken the logarithm, | |
the dual potentails should be given (that is the logarithm of the u,v sinkhorn scaling vectors) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also remove the following line
ot/bregman.py
Outdated
@@ -24,7 +24,7 @@ | |||
from .backend import get_backend | |||
|
|||
|
|||
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, | |||
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
move the warmstart after the stpthr
ot/bregman.py
Outdated
@@ -154,35 +158,35 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, | |||
""" | |||
|
|||
if method.lower() == 'sinkhorn': | |||
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, | |||
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here
ot/bregman.py
Outdated
stopThr=stopThr, verbose=verbose, | ||
log=log, warn=warn, | ||
**kwargs) | ||
else: | ||
raise ValueError("Unknown method '%s'." % method) | ||
|
||
|
||
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, | ||
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here
ot/bregman.py
Outdated
@@ -407,6 +415,10 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, | |||
Regularization term >0 | |||
numItermax : int, optional | |||
Max number of iterations | |||
warmstart: tuple of arrays, shape (dim_a, dim_b), optional |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also here
ot/bregman.py
Outdated
@@ -546,7 +561,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, | |||
return u.reshape((-1, 1)) * K * v.reshape((1, -1)) | |||
|
|||
|
|||
def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, | |||
def sinkhorn_log(a, b, M, reg, numItermax=1000, warmstart=None, stopThr=1e-9, verbose=False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here move warmstart please
ot/bregman.py
Outdated
@@ -746,7 +771,7 @@ def get_logT(u, v): | |||
return nx.exp(get_logT(u, v)) | |||
|
|||
|
|||
def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, | |||
def greenkhorn(a, b, M, reg, numItermax=10000, warmstart=None, stopThr=1e-9, verbose=False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also here
ot/bregman.py
Outdated
@@ -789,6 +814,10 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, | |||
Regularization term >0 | |||
numItermax : int, optional | |||
Max number of iterations | |||
warmstart: tuple of arrays, shape (dim_a, dim_b), optional | |||
Initialization of dual vectors. If provided, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also here,
ot/bregman.py
Outdated
@@ -2857,7 +2899,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, | |||
|
|||
|
|||
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', | |||
numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, | |||
numIterMax=10000, warmstart=None, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here
ot.bregman
ot.bregman
Types of changes
Add argument of warmstart/initialization of dual vectors to Sinkhorn-based methods.
Motivation and context / Related issue
The current Sinkhorn solvers
sinkhorn
,empirical_sinkhorn
,empirical_sinkhorn_divergence
use default intialization of dual vectors, and does not allow to take other initializations as input. This can be unnecessarily inefficient, in terms of time and computational cost, if meaningful initializations exist.How has this been tested (if it applies)
These changes have been tested on toy examples. The script of the tests is added to
test_bregman
.PR checklist